package connection

import (
	"github.com/DiracLee/dires-go/syncx"
	"net"
	"sync"
	"time"
)

// diresConnection wraps a net.Conn together with the per-client state kept
// for one connection: pub/sub subscriptions, a stored password, multi-state
// bookkeeping (queued command lines and a watching map) and the selected
// database index.
//
// Concurrency: mtx serializes Write calls and guards subs inside Subscribe
// and UnSubscribe. SubsCount, GetChannels and the remaining accessors read or
// write their fields without taking mtx.
// NOTE(review): presumably those are only called from the goroutine serving
// this connection (or under an external lock) — confirm before sharing the
// value across goroutines.
type diresConnection struct {
	conn       net.Conn            // underlying network connection; RemoteAddr, Write and Close delegate to it
	wg         syncx.WaitGroup     // counts in-flight Write calls; Close waits on it for up to 10 seconds
	mtx        sync.Mutex          // serializes Write; guards subs in Subscribe/UnSubscribe
	subs       map[string]struct{} // set of subscribed channel names; nil until the first Subscribe
	password   string              // value stored by SetPassword
	multiState bool                // set by SetMultiState; presumably true while a MULTI transaction is open — confirm
	queue      [][][]byte          // command lines appended by EnqueueCmd, each a [][]byte; cleared by ClearQueuedCmds/SetMultiState(false)
	watching   map[string]int64    // lazily allocated by GetWatching; reset to nil by SetMultiState(false)
	dbIndex    int                 // database index chosen via SelectDB
}

// RemoteAddr reports the address of the peer at the other end of the
// wrapped network connection.
func (conn *diresConnection) RemoteAddr() net.Addr {
	underlying := conn.conn
	return underlying.RemoteAddr()
}

// Write sends b to the client over the wrapped connection and returns the
// result of the underlying Write.
//
// An empty b is a no-op: it reports (0, nil) without touching the lock or
// the in-flight counter. Otherwise the call holds mtx for the whole write,
// so concurrent writers are serialized. It is also counted in wg, which lets
// Close wait for replies that are still being sent.
func (conn *diresConnection) Write(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}

	conn.mtx.Lock()
	conn.wg.Add(1)
	// Deferred calls run last-in-first-out: the counter is released first,
	// then the lock, matching the acquisition order in reverse.
	defer conn.mtx.Unlock()
	defer conn.wg.Done()

	return conn.conn.Write(b)
}

// Close waits up to ten seconds for in-flight Write calls (tracked by wg) to
// finish. It then closes the underlying connection and returns that Close's
// error. The socket is closed whether or not the wait completed in time.
func (conn *diresConnection) Close() error {
	const drainTimeout = 10 * time.Second
	conn.wg.WaitWithTimeout(drainTimeout)
	return conn.conn.Close()
}

// Subscribe adds channel to this connection's subscription set while holding
// mtx. The set is allocated on first use, and subscribing twice to the same
// channel is harmless.
func (conn *diresConnection) Subscribe(channel string) {
	conn.mtx.Lock()
	defer conn.mtx.Unlock()

	set := conn.subs
	if set == nil {
		set = make(map[string]struct{})
		conn.subs = set
	}
	set[channel] = struct{}{}
}

// UnSubscribe removes channel from this connection's subscription set while
// holding mtx. Removing a channel that is not subscribed, or calling this
// before any Subscribe, does nothing.
func (conn *diresConnection) UnSubscribe(channel string) {
	conn.mtx.Lock()
	defer conn.mtx.Unlock()

	// The map is inspected directly rather than through SubsCount, so this
	// locked section never calls back into another exported method.
	if len(conn.subs) > 0 {
		delete(conn.subs, channel)
	}
}

// SubsCount reports how many channels this connection is subscribed to
// (0 if it has never subscribed). It does not acquire mtx.
func (conn *diresConnection) SubsCount() int {
	count := len(conn.subs)
	return count
}

// GetChannels returns a freshly allocated slice of the subscribed channel
// names, in unspecified (map iteration) order. When there are no
// subscriptions, the result is a non-nil empty slice. It does not acquire mtx.
func (conn *diresConnection) GetChannels() []string {
	names := make([]string, len(conn.subs))
	idx := 0
	for name := range conn.subs {
		names[idx] = name
		idx++
	}
	return names
}

// SetPassword stores password on the connection, replacing any value stored
// earlier.
func (conn *diresConnection) SetPassword(password string) {
	conn.password = password
}

// GetPassword returns the password currently stored on the connection.
func (conn *diresConnection) GetPassword() (password string) {
	password = conn.password
	return password
}

// InMultiState reports the flag last set by SetMultiState.
func (conn *diresConnection) InMultiState() bool {
	active := conn.multiState
	return active
}

// SetMultiState records whether the connection is in multi state. Leaving
// multi state (state == false) also discards the queued command lines and
// the watching map, so the next session starts empty.
func (conn *diresConnection) SetMultiState(state bool) {
	conn.multiState = state
	if state {
		return
	}
	conn.queue = nil
	conn.watching = nil
}

// GetQueuedCmdLine returns the command lines queued by EnqueueCmd. The
// connection's own slice is returned without copying, so it may be nil.
func (conn *diresConnection) GetQueuedCmdLine() [][][]byte {
	queued := conn.queue
	return queued
}

// EnqueueCmd appends cmdLine to the end of the queued command lines. The
// slice is stored as-is, not copied.
func (conn *diresConnection) EnqueueCmd(cmdLine [][]byte) {
	pending := conn.queue
	conn.queue = append(pending, cmdLine)
}

// ClearQueuedCmds drops every queued command line. The multi-state flag and
// the watching map are left untouched.
func (conn *diresConnection) ClearQueuedCmds() {
	conn.queue = nil
}

// GetWatching returns the connection's watching map, allocating it on first
// use so the result is never nil. The map is returned by reference, so
// changes made by the caller persist on the connection.
func (conn *diresConnection) GetWatching() map[string]int64 {
	watched := conn.watching
	if watched == nil {
		watched = make(map[string]int64)
		conn.watching = watched
	}
	return watched
}

// GetDBIndex returns the database index currently selected on this
// connection.
func (conn *diresConnection) GetDBIndex() int {
	index := conn.dbIndex
	return index
}

// SelectDB makes i the connection's current database index. The value is
// stored as given; no range check is performed here.
func (conn *diresConnection) SelectDB(i int) {
	conn.dbIndex = i
}
